Fig 1. Accurate decoding of position with a RNN

This code example is a Jupyter notebook that uses a Script of Scripts (SoS) workflow to run Python 2 and Python 3 kernels in the same notebook. It reproduces the interactive figures for the paper by first author Ardi Tampuu and last author Raul Vicente.

The calculations are written in Python 2.7 (code fetched from the authors' repository), and the interactive figures are written in Python 3.6 with the plotting library Plotly.

Figure 1:

Location decoding errors based on CA1 neural data recorded in a 1 m square open-field environment, as a function of time window size. (a) Mean error; (b) median error. Blue lines show errors of the RNN decoder and red lines those of the Bayesian approaches. Results for the RNN approach are averaged over independent realizations of the training algorithm, and black dots show the mean/median error of each individual model. Results are shown for animal R2192.

https://doi.org/10.1371/journal.pcbi.1006822.g001
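
The caption reports the mean and the median of the per-sample decoding error. Below is a minimal sketch of how such summaries are typically computed. It assumes that each prediction's error is the Euclidean distance between the predicted and true (x, y) positions in cm. The helper names are hypothetical and do not come from the authors' code. The toy positions show why both statistics are plotted: one large miss pulls the mean up but leaves the median unchanged.

```python
import numpy as np


def decoding_errors(pred_xy, true_xy):
    """Euclidean distance between predicted and true positions (hypothetical helper)."""
    pred_xy = np.asarray(pred_xy, dtype=float)
    true_xy = np.asarray(true_xy, dtype=float)
    return np.linalg.norm(pred_xy - true_xy, axis=1)


def mean_and_median_error(pred_xy, true_xy):
    """Return (mean, median) of the per-sample errors, as in panels (a) and (b)."""
    err = decoding_errors(pred_xy, true_xy)
    return float(np.mean(err)), float(np.median(err))


# Made-up positions in cm: one exact hit, one 5 cm miss, one 50 cm miss
true_xy = [[0, 0], [0, 0], [0, 0]]
pred_xy = [[0, 0], [3, 4], [30, 40]]
mean_err, median_err = mean_and_median_error(pred_xy, true_xy)
print(mean_err, median_err)
```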

Python 2.7


Calculations for Fig 1:

%use Python2 
import matplotlib.style
import matplotlib as mpl
mpl.style.use('classic')
import numpy as np
from scipy.io import loadmat

# Set up a nicer color scheme: the Tableau 20 palette as (R, G, B) in 0-255
tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
             (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]
# Rescale to values between 0 and 1, the format Matplotlib expects
tableau20 = [(r / 255., g / 255., b / 255.) for (r, g, b) in tableau20]
    
# This file contains the performance of 10 models on R2192, "grep"-ed from the actual log files
with open("R2192_grepped_predictions.log") as f:
    lines = f.readlines()

# Fill a dictionary whose keys are time window sizes (ms) and whose values
# are [mean, median] error pairs, one per model (10 models)
RNN_stats = {}
for line in lines:
    pieces = line.split(" ")
    # the window size sits in the first field, between an "x" and the next "_"
    win_size = pieces[0].split("x")[1]
    win_size = int(win_size[:win_size.find("_")])
    median = float(pieces[-1])
    mean = float(pieces[-5][:-1])  # drop the trailing separator character
    RNN_stats.setdefault(win_size, []).append([mean, median])
# print RNN_stats[1400]

# Results with Bayes with flat prior (MLE)
# imported as a dictionary, each item in dictionary contains results for all 5 rats
# first rat is R2192 (index 0)
bay_field_dict = loadmat("Bayes_res/Fig1&3ab_decodingDataForOpenField.mat")
#bay_field_dict = loadmat("Bayes_res/2dDecodeAllRatsAllWindowsNaiveBayesNoTruncate.mat")
# print bay_field_dict.keys()
flat_medians=bay_field_dict['medianErr']
flat_means= bay_field_dict['meanErr']
flat_win = bay_field_dict['tWin2Test'].flatten()
# print flat_win.shape, flat_means[:,0].shape
# print "\n for Table 1: ", bay_field_dict["bstMean"]

# """ for Table 1:  [[ 15.82963073  16.06929415  17.86089428  18.81734775  17.03691594]
#  [  2.8          3.8          2.8          2.8          3.4       ]]"""

# Results with Bayes with memory (decoder that also uses the recent history)
# imported as a dictionary, each item in dictionary contains results for all 5 rats
# first rat is R2192 (index 0)

# bay_field_dict_history_h5 = loadmat("Bayes_res/oldParams_2dDecodeFullBayesWithHistorySigma1History5.mat") #old params
bay_field_dict_history = loadmat("Bayes_res/Fig1&3ab_2dDecodeFullBayesWithHistorySigma1History15.mat")
#bay_field_dict_history = loadmat("Bayes_res/2dDecodeAllRatsAllWindowsBayesWithHistorySpeed1History15NoTruncate.mat")
# print bay_field_dict_history.keys()
memory_medians = bay_field_dict_history['medianErr']
memory_means = bay_field_dict_history['meanErr']
memory_win = bay_field_dict_history['tWin2Test'].flatten()

# print memory_means[:,0]
# print memory_medians[:,0]
# print np.min(memory_medians, axis=0)

# print "\n for Table 1: ", bay_field_dict_history["bstMean"]
# """ for Table 1:  [[ 15.46168191  14.99576142  16.5269506   18.26098815  16.40828295]
#  [  2.           1.8          2.8          2.6          3.4       ]]"""
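
The Bayesian curves are loaded from precomputed .mat files, so the decoder itself is not part of this notebook. The sketch below is a minimal version of the textbook flat-prior decoder. It assumes independent Poisson spiking and known firing-rate maps for each spatial bin. This is a generic formulation, not necessarily the implementation that produced these files, and `poisson_mle_decode` is a hypothetical name. With a flat prior the posterior is proportional to the likelihood, so the maximum a posteriori estimate equals the maximum-likelihood estimate. The window length scales the expected spike counts, so longer windows accumulate more spikes per estimate.

```python
import numpy as np


def poisson_mle_decode(counts, rate_maps, window_s):
    """Return the spatial bin that maximizes the Poisson likelihood (flat prior).

    counts:    (n_cells,) spike counts observed in one time window
    rate_maps: (n_cells, n_bins) mean firing rate (Hz) of each cell in each bin
    window_s:  window length in seconds
    """
    counts = np.asarray(counts, dtype=float)
    expected = np.asarray(rate_maps, dtype=float) * window_s + 1e-12  # avoid log(0)
    # log P(counts | bin) up to a bin-independent constant:
    # sum over cells of counts_i * log(expected_i) - expected_i
    log_lik = counts @ np.log(expected) - expected.sum(axis=0)
    return int(np.argmax(log_lik))


# Made-up rate maps: cell 0 prefers bin 0, cell 1 prefers bin 2
rate_maps = [[10.0, 1.0, 1.0],
             [1.0, 1.0, 10.0]]
print(poisson_mle_decode([10, 0], rate_maps, window_s=1.0))  # → 0
print(poisson_mle_decode([0, 10], rate_maps, window_s=1.0))  # → 2
```

The "decoder with memory" curves add a history-dependent prior on top of this likelihood, which the sketch does not cover.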

Figure 1.a, Matplotlib:

%use Python2
import matplotlib.pyplot as plt
import numpy as np

average = []
# draw individual model's performance as dots
plt.figure(figsize=(8,4.5))
for size in sorted(RNN_stats.keys()):
    means = np.array(RNN_stats[size])[:,0]
    plt.plot([size]*len(means), means, "o", color="black")
    average.append(np.mean(means))

# print average, "\n",flat_means[:,0]

plt.xlabel("Time window size (ms)",fontsize=20)
plt.ylabel("Mean error (cm)",fontsize=20)

plt.xticks(fontsize=17)
plt.yticks(fontsize=18)

plt.xlim([175,4025])
plt.ylim([9,27])
plt.xticks(np.arange(200,4001,400))

plt.plot(flat_win*1000, flat_means[:,0], color=tableau20[6], linewidth=4,label="Bayesian with flat prior (MLE)", linestyle="dashed")
plt.plot(memory_win*1000, memory_means[:,0], color=tableau20[6], linewidth=4,label="Bayesian decoder with memory")
plt.plot(sorted(RNN_stats.keys()), average, color=tableau20[0], linewidth=4, label="RNN decoder")

plt.legend(fontsize=17)
plt.title("(a) Mean errors with different window size",fontsize=22,y=1.02)
plt.tight_layout()
plt.savefig("R2192_windows_mean_3models.png")
plt.show()
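
The calculation cell groups the per-model results by time window size. Each plotting cell then averages one column of `RNN_stats` (0 for the mean error, 1 for the median error) to draw the blue line over the black dots. The sketch below turns both steps into reusable functions. It assumes the log lines have already been parsed into `(window_ms, mean_err, median_err)` tuples. `group_by_window` and `rnn_summary` are hypothetical helpers, and the numbers are made up.

```python
from collections import defaultdict

import numpy as np


def group_by_window(records):
    """Group (window_ms, mean_err, median_err) tuples into {window_ms: [[mean, median], ...]}."""
    stats = defaultdict(list)
    for win_size, mean_err, median_err in records:
        stats[win_size].append([mean_err, median_err])
    return dict(stats)


def rnn_summary(stats, col):
    """col=0 selects mean errors, col=1 median errors.

    Returns the black-dot coordinates (one dot per model) and the
    averaged line (one point per window size).
    """
    windows = sorted(stats)
    dots_x, dots_y, average = [], [], []
    for w in windows:
        vals = np.array(stats[w])[:, col]
        dots_x.extend([w] * len(vals))
        dots_y.extend(vals.tolist())
        average.append(float(np.mean(vals)))
    return dots_x, dots_y, windows, average


# Made-up results: two models at 200 ms, one model at 400 ms
records = [(200, 20.0, 15.0), (200, 22.0, 17.0), (400, 18.0, 14.0)]
stats = group_by_window(records)
print(rnn_summary(stats, 0))
```

Passing `col=1` to the same function gives the data for panel (b), so the two panels share one code path.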

Figure 1.b, Matplotlib:

%use Python2
import matplotlib.pyplot as plt
import numpy as np
average = []

# draw individual model's performance as dots
plt.figure(figsize=(8,4.5))
for size in sorted(RNN_stats.keys()):
    medians = np.array(RNN_stats[size])[:,1]
    plt.plot([size]*len(medians), medians, "o", color="black")
    average.append(np.mean(medians))
# print average, "\n" 
# print flat_medians[:,0] 
# print memory_medians[:,0]

#plt.title("Median prediction error for R2192 in function of timewindow size",fontsize=20)
plt.xlabel("Time window size (ms)",fontsize=20)
plt.ylabel("Median error (cm)",fontsize=20)

plt.xlim([175,4025])
plt.ylim([9,27])
plt.xticks(fontsize=17)
plt.yticks(fontsize=18)
plt.xticks(np.arange(200,4001,400))

plt.plot(flat_win*1000, flat_medians[:,0], color=tableau20[6], linewidth=4,label="Bayesian with flat prior (MLE)",linestyle="dashed")
plt.plot(memory_win*1000, memory_medians[:,0], color=tableau20[6], linewidth=4,label="Bayesian decoder with memory")
plt.plot(sorted(RNN_stats.keys()), average, color=tableau20[0], linewidth=4,label="RNN decoder")

plt.legend(fontsize=17)
plt.title("(b) Median errors with different window size",fontsize=22,y=1.02)
plt.tight_layout()
plt.savefig("R2192_windows_median_3models.png")
plt.show()

Save the results so that the Python 3 kernel can load them:

%use Python2
import pickle

with open('train.pickle', 'wb') as f:
    pickle.dump([size, means, RNN_stats, flat_win, flat_means, flat_medians, memory_medians, tableau20, memory_win, memory_means, average], f)
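
The pickle file is the hand-off between the Python 2 and Python 3 kernels. Python 2.7 supports pickle protocols 0 to 2 and writes protocol 0 by default. Passing `protocol=2` produces a more compact binary file that Python 3 can still read. On the Python 3 side, the `encoding` argument of `pickle.load` controls how Python 2 `str` objects are decoded. The cells here use `encoding='bytes'`, whereas the Python documentation notes that `encoding='latin1'` is required for NumPy arrays pickled by Python 2. The sketch below only round-trips through an in-memory buffer in Python 3 to show the calls. The payload is a made-up stand-in for the real variables.

```python
import io
import pickle

import numpy as np

# Made-up stand-ins for the variables saved by the Python 2 cell
payload = {
    "RNN_stats": {200: [[20.0, 15.0], [22.0, 17.0]]},
    "flat_win": np.array([0.2, 0.4, 0.8]),
}

buf = io.BytesIO()
# Protocol 2 is the highest protocol Python 2.7 can read and write
pickle.dump(payload, buf, protocol=2)
buf.seek(0)

# encoding only affects str objects written by Python 2; it is harmless here
restored = pickle.load(buf, encoding="latin1")
print(restored["RNN_stats"][200], restored["flat_win"])
```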

Python 3.7


(a) Mean prediction error for R2192 as a function of time window size

Figure 1.a:

%use Python3
import pickle
import numpy as np
import plotly.graph_objects as go
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.display import display, HTML
init_notebook_mode(connected = True)
config={'showLink': False, 'displayModeBar': False}

# Load the results computed in the Python 2 kernel
with open('train.pickle', 'rb') as f:
    size, means, RNN_stats, flat_win, flat_means, flat_medians, memory_medians, tableau20, memory_win, memory_means, average = pickle.load(f, encoding='bytes')

figa = go.Figure()

def to_rgb(c):
    # tableau20 holds 0-1 floats (rescaled for Matplotlib), but Plotly's
    # "rgb(...)" strings expect 0-255 channel values
    return "rgb({:.0f},{:.0f},{:.0f})".format(*(255 * v for v in c))

average = []
# draw individual model's performance as dots
for size in sorted(RNN_stats.keys()):
    means = np.array(RNN_stats[size])[:,0]
    figa.add_trace(go.Scatter(x = [size]*len(means),
                              y = means,
                              mode = 'markers',
                              showlegend=False,
                              marker = dict(color="black")))
    average.append(np.mean(means))

figa.add_trace(go.Scatter(x = flat_win*1000,
                          y = flat_means[:,0],
                          mode = 'lines',
                          line=dict(color=to_rgb(tableau20[6]),
                                    width=4,
                                    dash="dash"),
                          name = "Bayesian with flat prior (MLE)"))

figa.add_trace(go.Scatter(x = memory_win*1000,
                          y = memory_means[:,0],
                          mode = 'lines',
                          line=dict(color=to_rgb(tableau20[6]),
                                    width=4),
                          name = "Bayesian decoder with memory"))

figa.add_trace(go.Scatter(x = sorted(RNN_stats.keys()),
                          y = average,
                          mode = 'lines',
                          line=dict(color=to_rgb(tableau20[0]),
                                    width=4),
                          name = "RNN decoder"))


figa.update_layout(title = '(a) Mean errors with different window size',
                  title_x = 0.5, 
                  xaxis_title='Time window size (ms)',
                  xaxis=dict(range=[175,4025], 
                             mirror=True,
                             ticks='outside',
                             showline=True,
                             linecolor='#000',
                             tickfont = dict(size=20)),
                  yaxis_title='Mean error (cm)',
                  yaxis=dict(range=[9,27], 
                             mirror=True,
                             ticks='outside', 
                             showline=True, 
                             linecolor='#000',
                             tickfont = dict(size=20)),
                  legend=dict(x=0.45, 
                              y=0.8,
                              bordercolor="Gray",
                              borderwidth=1),
                  plot_bgcolor='#fff', 
                  width = 800, 
                  height = 480,
                  font = dict(size = 17),
                  margin=go.layout.Margin(l=50,
                                         r=50,
                                         b=60,
                                         t=35))

plot(figa, filename = 'fig1_a.html', config = config)
# THEBELAB
display(HTML('fig1_a.html'))
# BINDER
# iplot(figa,config=config)

(b) Median prediction error for R2192 as a function of time window size

Figure 1.b:

%use Python3
import pickle
import numpy as np
import plotly.graph_objects as go
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.display import display, HTML
init_notebook_mode(connected = True)
config={'showLink': False, 'displayModeBar': False}

# Load the results computed in the Python 2 kernel
with open('train.pickle', 'rb') as f:
    size, means, RNN_stats, flat_win, flat_means, flat_medians, memory_medians, tableau20, memory_win, memory_means, average = pickle.load(f, encoding='bytes')

figb = go.Figure()

def to_rgb(c):
    # tableau20 holds 0-1 floats (rescaled for Matplotlib), but Plotly's
    # "rgb(...)" strings expect 0-255 channel values
    return "rgb({:.0f},{:.0f},{:.0f})".format(*(255 * v for v in c))

average = []
# draw individual model's performance as dots
for size in sorted(RNN_stats.keys()):
    medians = np.array(RNN_stats[size])[:,1]
    figb.add_trace(go.Scatter(x = [size]*len(medians),
                              y = medians,
                              mode = 'markers',
                              showlegend=False,
                              marker = dict(color="black")))
    average.append(np.mean(medians))

figb.add_trace(go.Scatter(x = flat_win*1000,
                          y = flat_medians[:,0],
                          mode = 'lines',
                          line=dict(color=to_rgb(tableau20[6]),
                                    width=4,
                                    dash="dash"),
                          name = "Bayesian with flat prior (MLE)"))

figb.add_trace(go.Scatter(x = memory_win*1000,
                          y = memory_medians[:,0],
                          mode = 'lines',
                          line=dict(color=to_rgb(tableau20[6]),
                                    width=4),
                          name = "Bayesian decoder with memory"))

figb.add_trace(go.Scatter(x = sorted(RNN_stats.keys()),
                          y = average,
                          mode = 'lines',
                          line=dict(color=to_rgb(tableau20[0]),
                                    width=4),
                          name = "RNN decoder"))


figb.update_layout(title = '(b) Median errors with different window size',
                  title_x = 0.5, 
                  xaxis_title='Time window size (ms)',
                  xaxis=dict(range=[175,4025], 
                             mirror=True,
                             ticks='outside',
                             showline=True,
                             linecolor='#000',       
                             tickfont = dict(size=20)),
                  yaxis_title='Median error (cm)',
                  yaxis=dict(range=[9,27], 
                             mirror=True,
                             ticks='outside', 
                             showline=True, 
                             linecolor='#000',
                             tickfont = dict(size=20)),
                  legend=dict(x=0.45, 
                              y=0.8,
                              bordercolor="Gray",
                              borderwidth=1),
                  plot_bgcolor='#fff', 
                  width = 800, 
                  height = 480,
                  font = dict(size = 17),
                  margin=go.layout.Margin(l=50,
                                         r=50,
                                         b=60,
                                         t=35))

plot(figb, filename = 'fig1_b.html', config = config)
# THEBELAB
display(HTML('fig1_b.html'))
# BINDER
# iplot(figb,config=config)